// Copyright 2000-2022 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.

package org.jetbrains.kotlin.idea.inspections.coroutines

import com.intellij.codeInsight.FileModificationService
import com.intellij.codeInspection.IntentionWrapper
import com.intellij.codeInspection.LocalQuickFix
import com.intellij.codeInspection.ProblemDescriptor
import com.intellij.codeInspection.ProblemHighlightType.GENERIC_ERROR_OR_WARNING
import com.intellij.codeInspection.ProblemsHolder
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElementVisitor
import org.jetbrains.kotlin.descriptors.DeclarationDescriptor
import org.jetbrains.kotlin.descriptors.ReceiverParameterDescriptor
import org.jetbrains.kotlin.idea.base.resources.KotlinBundle
import org.jetbrains.kotlin.idea.caches.resolve.analyzeWithContent
import org.jetbrains.kotlin.idea.caches.resolve.resolveToDescriptorIfAny
import org.jetbrains.kotlin.idea.core.ShortenReferences
import org.jetbrains.kotlin.idea.base.psi.replaced
import org.jetbrains.kotlin.idea.base.util.reformatted
import org.jetbrains.kotlin.idea.codeinsight.api.classic.inspections.AbstractKotlinInspection
import org.jetbrains.kotlin.idea.inspections.UnusedReceiverParameterInspection
import org.jetbrains.kotlin.idea.intentions.ConvertReceiverToParameterIntention
import org.jetbrains.kotlin.idea.intentions.MoveMemberToCompanionObjectIntention
import org.jetbrains.kotlin.idea.util.application.executeWriteCommand
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.psi.psiUtil.containingClassOrObject
import org.jetbrains.kotlin.psi.psiUtil.forEachDescendantOfType
import org.jetbrains.kotlin.psi.psiUtil.getNonStrictParentOfType
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.calls.util.getResolvedCall
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.scopes.receivers.ExpressionReceiver
import org.jetbrains.kotlin.resolve.scopes.receivers.ImplicitReceiver
import org.jetbrains.kotlin.resolve.scopes.receivers.ReceiverValue
import org.jetbrains.kotlin.types.KotlinType

/**
 * Inspects `suspend` functions that have a body and whose extension or dispatch receiver type is
 * (or directly extends) `kotlinx.coroutines.CoroutineScope`.
 *
 * Inside such a function two kinds of usages are reported:
 *  - calls of extension functions declared on `CoroutineScope` whose extension receiver is the receiver
 *    of the inspected function;
 *  - references to `CoroutineScope.coroutineContext` whose dispatch receiver is that same receiver.
 *
 * Every suspicious usage gets a "wrap this call" fix; the function declaration itself is reported once
 * with the fixes that affect the whole function.
 */
class SuspendFunctionOnCoroutineScopeInspection : AbstractKotlinInspection() {

    override fun buildVisitor(holder: ProblemsHolder, isOnTheFly: Boolean): PsiElementVisitor {
        return namedFunctionVisitor(fun(function: KtNamedFunction) {
            if (!function.hasModifier(KtTokens.SUSPEND_KEYWORD)) return
            if (!function.hasBody()) return

            val context = function.analyzeWithContent()
            val descriptor = context[BindingContext.FUNCTION, function] ?: return
            val extensionOfCoroutineScope = descriptor.extensionReceiverParameter.ofCoroutineScopeType()
            val memberOfCoroutineScope = descriptor.dispatchReceiverParameter.ofCoroutineScopeType()
            if (!extensionOfCoroutineScope && !memberOfCoroutineScope) return

            // The declaration-level problem is identical for every suspicious usage,
            // so it is registered only for the first one found in this function.
            var declarationReported = false

            // True when this descriptor owns one of the CoroutineScope receivers of the inspected function:
            // the function itself for its extension receiver, its containing declaration for the dispatch receiver.
            fun DeclarationDescriptor.isReceiverOfAnalyzedFunction(): Boolean =
                (extensionOfCoroutineScope && this == descriptor) ||
                        (memberOfCoroutineScope && this == descriptor.containingDeclaration)

            // Reports the function declaration (receiver type, else name, else `fun` keyword, else the function)
            // together with the fixes that apply to the function as a whole.
            fun reportOnDeclaration() {
                if (declarationReported) return
                declarationReported = true

                val fixes = mutableListOf<LocalQuickFix>(
                    WrapWithCoroutineScopeFix(removeReceiver = extensionOfCoroutineScope, wrapCallOnly = false)
                )
                if (extensionOfCoroutineScope) {
                    fixes += IntentionWrapper(ConvertReceiverToParameterIntention())
                }
                if (memberOfCoroutineScope) {
                    // The function is known to have a body here (checked at the top of the visitor).
                    val containingDeclaration = function.containingClassOrObject
                    if (containingDeclaration is KtClass && !containingDeclaration.isInterface()) {
                        fixes += IntentionWrapper(MoveMemberToCompanionObjectIntention())
                    }
                }

                holder.registerProblem(
                    with(function) { receiverTypeReference ?: nameIdentifier ?: funKeyword ?: this },
                    MESSAGE,
                    GENERIC_ERROR_OR_WARNING,
                    *fixes.toTypedArray()
                )
            }

            // Reports [problemExpression] when [receiver] denotes a CoroutineScope receiver of the inspected function.
            fun checkSuspiciousReceiver(receiver: ReceiverValue, problemExpression: KtExpression) {
                when (receiver) {
                    is ImplicitReceiver -> if (!receiver.declarationDescriptor.isReceiverOfAnalyzedFunction()) return
                    is ExpressionReceiver -> {
                        val receiverThisExpression = receiver.expression as? KtThisExpression ?: return
                        // Resolve the target of `this` whether or not it is labeled: an unlabeled `this` inside a
                        // nested receiver lambda or object refers to that inner receiver, which must not be reported
                        // (the implicit-receiver branch above already treats such receivers as unrelated).
                        val thisTarget = context[BindingContext.REFERENCE_TARGET, receiverThisExpression.instanceReference]
                        if (thisTarget?.isReceiverOfAnalyzedFunction() != true) return
                    }
                }

                // For a call, highlight only the callee so that arguments and lambdas are not covered.
                val reportElement = (problemExpression as? KtCallExpression)?.calleeExpression ?: problemExpression
                holder.registerProblem(
                    reportElement,
                    MESSAGE,
                    GENERIC_ERROR_OR_WARNING,
                    WrapWithCoroutineScopeFix(removeReceiver = false, wrapCallOnly = true)
                )
                reportOnDeclaration()
            }

            // Calls of extension functions whose extension receiver parameter is exactly CoroutineScope.
            function.forEachDescendantOfType(fun(callExpression: KtCallExpression) {
                val resolvedCall = callExpression.getResolvedCall(context) ?: return
                val extensionReceiverParameter = resolvedCall.resultingDescriptor.extensionReceiverParameter ?: return
                if (!extensionReceiverParameter.type.isCoroutineScope()) return
                val extensionReceiver = resolvedCall.extensionReceiver ?: return
                checkSuspiciousReceiver(extensionReceiver, callExpression)
            })
            // References that resolve to the `coroutineContext` member of CoroutineScope.
            function.forEachDescendantOfType(fun(nameReferenceExpression: KtNameReferenceExpression) {
                if (nameReferenceExpression.getReferencedName() != COROUTINE_CONTEXT) return
                val resolvedCall = nameReferenceExpression.getResolvedCall(context) ?: return
                if (resolvedCall.resultingDescriptor.fqNameSafe.asString() == "$COROUTINE_SCOPE.$COROUTINE_CONTEXT") {
                    val dispatchReceiver = resolvedCall.dispatchReceiver ?: return
                    checkSuspiciousReceiver(dispatchReceiver, nameReferenceExpression)
                }
            })
        })
    }

    /**
     * Wraps either a single reported call chain or the whole function body into
     * `kotlinx.coroutines.coroutineScope { ... }`.
     *
     * @param removeReceiver when `true` (and [wrapCallOnly] is `false`), the receiver type of the function
     * is removed after the body has been wrapped.
     * @param wrapCallOnly when `true`, only the outermost qualified/call expression around the reported
     * element is wrapped; otherwise the whole body is wrapped.
     */
    private class WrapWithCoroutineScopeFix(
        private val removeReceiver: Boolean,
        private val wrapCallOnly: Boolean
    ) : LocalQuickFix {
        override fun getFamilyName(): String = KotlinBundle.message("wrap.with.coroutine.scope.fix.family.name")

        override fun getName(): String =
            when {
                removeReceiver && !wrapCallOnly -> KotlinBundle.message("wrap.with.coroutine.scope.fix.text3")
                wrapCallOnly -> KotlinBundle.message("wrap.with.coroutine.scope.fix.text2")
                else -> KotlinBundle.message("wrap.with.coroutine.scope.fix.text")
            }

        // The fix resolves code first and opens its own write action / write command afterwards.
        override fun startInWriteAction() = false

        override fun applyFix(project: Project, descriptor: ProblemDescriptor) {
            val problemElement = descriptor.psiElement ?: return
            val function = problemElement.getNonStrictParentOfType<KtNamedFunction>() ?: return
            val functionDescriptor = function.resolveToDescriptorIfAny()
            if (!FileModificationService.getInstance().preparePsiElementForWrite(function)) return
            val bodyExpression = function.bodyExpression

            // Climbs from the reported element to the outermost enclosing qualified/call expression.
            fun getExpressionToWrapCall(): KtExpression? {
                var result = problemElement as? KtExpression ?: return null
                while (result.parent is KtQualifiedExpression || result.parent is KtCallExpression) {
                    result = result.parent as KtExpression
                }
                return result
            }

            var expressionToWrap = when {
                wrapCallOnly -> getExpressionToWrapCall()
                else -> bodyExpression
            } ?: return
            if (functionDescriptor?.extensionReceiverParameter.ofCoroutineScopeType()) {
                val context = function.analyzeWithContent()
                // Replace `this@label.selector` with `selector` wherever the labeled `this` is the extension
                // receiver of this function — presumably so that the selector binds to the scope introduced
                // by the wrapper instead of the original receiver.
                expressionToWrap.forEachDescendantOfType<KtDotQualifiedExpression> {
                    val receiverExpression = it.receiverExpression as? KtThisExpression
                    val selectorExpression = it.selectorExpression
                    if (receiverExpression?.getTargetLabel() != null && selectorExpression != null) {
                        if (context[BindingContext.REFERENCE_TARGET, receiverExpression.instanceReference] == functionDescriptor) {
                            runWriteAction {
                                if (it === expressionToWrap) {
                                    // The root itself is replaced, so keep tracking the new element.
                                    expressionToWrap = it.replaced(selectorExpression)
                                } else {
                                    it.replace(selectorExpression)
                                }
                            }
                        }
                    }
                }
            }

            val psiFactory = KtPsiFactory(project)
            val blockExpression = function.bodyBlockExpression
            project.executeWriteCommand(name, this) {
                val result = when {
                    // Only a part of the body (a single call chain) is wrapped.
                    expressionToWrap != bodyExpression -> expressionToWrap.replaced(
                        psiFactory.createExpressionByPattern("$COROUTINE_SCOPE_WRAPPER { $0 }", expressionToWrap)
                    )
                    // Expression body: wrap the expression itself.
                    blockExpression == null -> bodyExpression.replaced(
                        psiFactory.createExpressionByPattern("$COROUTINE_SCOPE_WRAPPER { $0 }", bodyExpression)
                    )
                    // Block body: rebuild the block with all statements moved inside the wrapper lambda.
                    else -> {
                        val bodyText = buildString {
                            for (statement in blockExpression.statements) {
                                append(statement.text)
                                append("\n")
                            }
                        }
                        blockExpression.replaced(
                            psiFactory.createBlock("$COROUTINE_SCOPE_WRAPPER { $bodyText }")
                        )
                    }
                }
                // The wrapper is inserted by its fully qualified name; shorten it after reformatting.
                ShortenReferences.DEFAULT.process(result.reformatted() as KtElement)
            }

            val receiverTypeReference = function.receiverTypeReference
            if (removeReceiver && !wrapCallOnly && receiverTypeReference != null) {
                UnusedReceiverParameterInspection.RemoveReceiverFix.apply(receiverTypeReference, project)
            }
        }
    }
}

// Fully qualified name of the CoroutineScope type that receivers and callee extension receivers are compared against.
private const val COROUTINE_SCOPE = "kotlinx.coroutines.CoroutineScope"

// Fully qualified name of the function that the quick fix wraps code with.
private const val COROUTINE_SCOPE_WRAPPER = "kotlinx.coroutines.coroutineScope"

// Simple name of the CoroutineScope member whose references are inspected.
private const val COROUTINE_CONTEXT = "coroutineContext"

// Problem description; declared with a getter, so the bundle lookup happens on every access.
private val MESSAGE get() = KotlinBundle.message("ambiguous.coroutinecontext.due.to.coroutinescope.receiver.of.suspend.function")

/**
 * Returns `true` when the declaration behind this type's constructor has the fully qualified name
 * [COROUTINE_SCOPE]; returns `false` when the constructor has no declaration descriptor.
 */
private fun KotlinType.isCoroutineScope(): Boolean {
    val classifier = constructor.declarationDescriptor ?: return false
    return classifier.fqNameSafe.asString() == COROUTINE_SCOPE
}

/**
 * Returns `true` when this receiver parameter exists and its type either is `CoroutineScope`
 * or lists `CoroutineScope` among the supertypes of its type constructor.
 *
 * A `null` receiver yields `false`.
 *
 * NOTE(review): only `type.constructor.supertypes` is consulted, which looks like the directly declared
 * supertypes — types that inherit `CoroutineScope` indirectly may not be matched; confirm whether that is intended.
 */
private fun ReceiverParameterDescriptor?.ofCoroutineScopeType(): Boolean {
    val receiverType = this?.type ?: return false
    if (receiverType.isCoroutineScope()) return true
    // `any` does not depend on iteration order, so the supertypes are scanned in place without copying them.
    return receiverType.constructor.supertypes.any { it.isCoroutineScope() }
}
